// Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights reserved.
//
// StratoVirt is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of the Mulan
// PSL v2.
// You may obtain a copy of Mulan PSL v2 at:
//         http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{fence, Ordering};

use anyhow::bail;
use kvm_bindings::__IncompleteArrayField;
use vmm_sys_util::eventfd::EventFd;

use super::{AioCb, AioContext, AioEvent, OpCode, Result};

/// Flag bit stored in `IoCb::aio_flags` by `submit`, always together with a valid
/// `IoCb::aio_resfd`.
/// NOTE(review): presumably mirrors the kernel's `IOCB_FLAG_RESFD` (aio_abi.h), which asks
/// the kernel to signal the eventfd in `aio_resfd` on completion — confirm against the header.
const IOCB_FLAG_RESFD: u32 = 1;
/// Second flag bit for `IoCb::aio_flags`; not used anywhere in this file, hence the
/// `dead_code` allowance.
/// NOTE(review): presumably mirrors the kernel's `IOCB_FLAG_IOPRIO` — confirm.
#[allow(dead_code)]
const IOCB_FLAG_IOPRIO: u32 = 1 << 1;

/// One completion record of the AIO completion ring; the element type of
/// `AioRing::io_events`.
///
/// `#[repr(C)]` fixes the field order because `get_events` reads these records directly
/// from memory that was not written by Rust code. Do not reorder or resize the fields.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Default, Clone)]
struct IoEvent {
    /// Copied into `AioEvent::user_data` by `get_events`.
    /// NOTE(review): presumably echoes the `IoCb::data` value of the completed request — confirm.
    data: u64,
    /// Not read anywhere in this file.
    /// NOTE(review): presumably the address of the originating iocb, as in the kernel's
    /// `struct io_event` — confirm.
    obj: u64,
    /// Copied into `AioEvent::res` by `get_events`.
    res: i64,
    /// Copied into `AioEvent::status` by `get_events`.
    res2: i64,
}

/// Request descriptor handed to `SYS_io_submit`; `submit` builds one per `AioCb`.
///
/// `#[repr(C)]` fixes the field order because the struct is read by the kernel. Fields that
/// `submit` does not set explicitly keep their `Default` value of zero.
/// NOTE(review): the layout is presumably that of the kernel's `struct iocb` (aio_abi.h) on
/// little-endian targets — confirm before touching any field.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Default)]
struct IoCb {
    /// Caller-chosen cookie; `submit` stores `AioCb::user_data` here.
    data: u64,
    /// Never set by `submit`; stays zero.
    key: u32,
    /// Never set by `submit`; stays zero.
    aio_reserved1: u32,
    /// Operation code; `submit` stores an `IoCmd` discriminant here.
    aio_lio_opcode: u16,
    /// Never set by `submit`; stays zero.
    aio_reqprio: u16,
    /// Target file descriptor; `submit` stores `AioCb::file_fd` here.
    aio_fildes: u32,
    /// Address of the request's iovec array, or zero for `Fdsync`.
    aio_buf: u64,
    /// `submit` stores the number of iovec entries (`iovec.len()`) here, not a byte count.
    aio_nbytes: u64,
    /// File offset of the request; `submit` stores `AioCb::offset` here.
    aio_offset: u64,
    /// Never set by `submit`; stays zero.
    aio_reserved2: u64,
    /// Flag bits; `submit` always stores `IOCB_FLAG_RESFD`.
    aio_flags: u32,
    /// `submit` stores the eventfd (`LibaioContext::resfd`) here.
    aio_resfd: u32,
}

/// Operation codes written into `IoCb::aio_lio_opcode` (cast to `u16` by `submit`).
///
/// Only `Preadv`, `Pwritev` and `Fdsync` are produced by `submit`; the remaining variants
/// are unused in this file, which is why `dead_code` is allowed.
/// NOTE(review): the discriminants presumably mirror the kernel's `IOCB_CMD_*` values
/// (aio_abi.h) — confirm before adding or renumbering variants.
#[repr(C)]
#[allow(non_camel_case_types, dead_code)]
#[derive(Copy, Clone)]
enum IoCmd {
    /// Unused in this file.
    Pread = 0,
    /// Unused in this file.
    Pwrite = 1,
    /// Unused in this file.
    Fsync = 2,
    /// Emitted for `OpCode::Fdsync`.
    Fdsync = 3,
    /// Unused in this file.
    Noop = 6,
    /// Emitted for `OpCode::Preadv`.
    Preadv = 7,
    /// Emitted for `OpCode::Pwritev`.
    Pwritev = 8,
}

/// Uninhabited marker type that is only ever used behind a raw pointer (`*mut IoContext`):
/// the opaque context handle filled in by `SYS_io_setup` in `LibaioContext::probe` and
/// released with `SYS_io_destroy`. No value of this type can be constructed.
#[allow(non_camel_case_types)]
pub(crate) enum IoContext {}

/// Linux native AIO context driven through raw `io_setup` / `io_submit` / `io_destroy`
/// syscalls, with completions read directly from the ring that `ctx` points to.
pub(crate) struct LibaioContext {
    /// Handle produced by `SYS_io_setup`; destroyed in `Drop` when non-null. `get_events`
    /// also reinterprets this address as a `*mut AioRing`.
    ctx: *mut IoContext,
    /// Raw fd of the `EventFd` passed to `new`, written into every submitted `IoCb`.
    /// Only the raw fd is stored (the `EventFd` itself is borrowed), so the caller must
    /// keep the eventfd open for as long as this context is in use.
    resfd: RawFd,
    /// Buffer refilled by each `get_events` call; created with capacity `max_size`.
    events: Vec<AioEvent>,
}

impl Drop for LibaioContext {
    /// Releases the kernel AIO context with `SYS_io_destroy`.
    ///
    /// A null handle is skipped. The syscall's return value is ignored, so a failed
    /// destroy goes unreported.
    fn drop(&mut self) {
        if self.ctx.is_null() {
            return;
        }
        // SAFETY: self.ctx is non-null and was generated by SYS_io_setup.
        unsafe { libc::syscall(libc::SYS_io_destroy, self.ctx) };
    }
}

/// Header of the completion ring that `get_events` reads by casting `LibaioContext::ctx`
/// to `*mut AioRing`.
///
/// `#[repr(C)]` fixes the field order because the memory is not written by Rust code.
/// NOTE(review): the layout is presumably that of the kernel's `struct aio_ring` — confirm
/// before changing any field.
#[repr(C)]
#[derive(Default)]
struct AioRing {
    /// Not used in this file.
    id: u32,
    /// Number of `IoEvent` slots in `io_events`; used as the ring size (modulus).
    nr: u32,
    /// Index of the next event to consume; `get_events` advances it to `tail` when done.
    head: u32,
    /// Index one past the newest event; only read, never written, by this file.
    tail: u32,

    /// Not used in this file (`get_events` does not validate it).
    magic: u32,
    /// Not used in this file.
    compat_features: u32,
    /// Not used in this file.
    incompat_features: u32,
    /// Not used in this file.
    header_length: u32,

    /// Variable-length array of `nr` completion records that follows the header.
    io_events: __IncompleteArrayField<IoEvent>,
}

impl LibaioContext {
    /// Creates a kernel AIO context sized for `max_size` events via `SYS_io_setup` and
    /// returns the raw handle.
    ///
    /// The handle is NOT wrapped in a [`LibaioContext`], so nothing destroys it
    /// automatically: a caller that uses this only to check whether native AIO is available
    /// must release the handle with `SYS_io_destroy` itself.
    ///
    /// # Errors
    ///
    /// Returns an error when the syscall fails. The message contains both the raw return
    /// value and the OS error (errno), since the raw return alone is always `-1`.
    pub fn probe(max_size: u32) -> Result<*mut IoContext> {
        let mut ctx: *mut IoContext = std::ptr::null_mut();
        // SAFETY: ctx is a valid ptr.
        let ret = unsafe { libc::syscall(libc::SYS_io_setup, max_size, &mut ctx) };
        if ret < 0 {
            // Capture errno right away, before anything else can overwrite it.
            let os_err = std::io::Error::last_os_error();
            bail!(
                "Failed to setup linux native aio context, return {}, error: {}.",
                ret,
                os_err
            );
        }
        Ok(ctx)
    }

    /// Builds a context able to hold `max_size` events whose completions are signalled
    /// through `eventfd`.
    ///
    /// Only the raw fd of `eventfd` is stored, so the caller must keep `eventfd` open for as
    /// long as the returned context is in use.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`LibaioContext::probe`] when the kernel context cannot be
    /// created.
    pub fn new(max_size: u32, eventfd: &EventFd) -> Result<Self> {
        Ok(LibaioContext {
            ctx: Self::probe(max_size)?,
            resfd: eventfd.as_raw_fd(),
            events: Vec::with_capacity(max_size as usize),
        })
    }
}

/// Implements the AioContext for libaio.
impl<T: Clone> AioContext<T> for LibaioContext {
    fn submit(&mut self, iocbp: &[*const AioCb<T>]) -> Result<usize> {
        let mut iocbs = Vec::with_capacity(iocbp.len());
        for iocb in iocbp {
            // SAFETY: iocb is valid until request is finished.
            let cb = unsafe { &*(*iocb) };
            let opcode = match cb.opcode {
                OpCode::Preadv => IoCmd::Preadv,
                OpCode::Pwritev => IoCmd::Pwritev,
                OpCode::Fdsync => IoCmd::Fdsync,
                _ => bail!("Failed to submit aio, opcode is not supported."),
            };
            let aio_buf = match cb.opcode {
                OpCode::Fdsync => 0,
                _ => cb.iovec.as_ptr() as u64,
            };
            iocbs.push(IoCb {
                data: cb.user_data,
                aio_lio_opcode: opcode as u16,
                aio_fildes: cb.file_fd as u32,
                aio_buf,
                aio_nbytes: cb.iovec.len() as u64,
                aio_offset: cb.offset as u64,
                aio_flags: IOCB_FLAG_RESFD,
                aio_resfd: self.resfd as u32,
                ..Default::default()
            });
        }

        // SYS_io_submit needs vec of references.
        let mut iocbp = Vec::with_capacity(iocbs.len());
        for iocb in iocbs.iter() {
            iocbp.push(iocb);
        }

        // SAFETY: self.ctx is generated by SYS_io_setup.
        let ret =
            unsafe { libc::syscall(libc::SYS_io_submit, self.ctx, iocbp.len(), iocbp.as_ptr()) };
        if ret >= 0 {
            return Ok(ret as usize);
        }
        if nix::errno::errno() != libc::EAGAIN {
            bail!("Failed to submit aio, return {}.", ret);
        }
        Ok(0)
    }

    fn get_events(&mut self) -> &[AioEvent] {
        let ring = self.ctx as *mut AioRing;
        // SAFETY: self.ctx is generated by SYS_io_setup.
        let head = unsafe { (*ring).head };
        let tail = unsafe { (*ring).tail };
        let ring_nr = unsafe { (*ring).nr };
        let io_events: &[IoEvent] = unsafe { (*ring).io_events.as_slice(ring_nr as usize) };

        let nr = if tail >= head {
            tail - head
        } else {
            ring_nr - head + tail
        };

        // Avoid speculatively loading ring.io_events before observing tail.
        fence(Ordering::Acquire);
        self.events.clear();
        for i in head..(head + nr) {
            let io_event = &io_events[(i % ring_nr) as usize];
            self.events.push(AioEvent {
                user_data: io_event.data,
                status: io_event.res2,
                res: io_event.res,
            })
        }

        // Avoid head is updated before we consume all io_events.
        fence(Ordering::Release);
        // SAFETY: self.ctx is generated by SYS_io_setup.
        unsafe { (*ring).head = tail };

        &self.events
    }
}
